{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The watermark extension is already loaded. To reload it, use:\n",
      "  %reload_ext watermark\n",
      "Author: Sebastian Raschka\n",
      "\n",
      "Python implementation: CPython\n",
      "Python version       : 3.8.8\n",
      "IPython version      : 7.30.1\n",
      "\n",
      "torch: 1.10.1\n",
      "numpy: 1.22.2\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch,numpy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Zoo -- Using PyTorch Dataset Loading Utilities for Custom Datasets (Images from Quickdraw)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook provides an example of how to load an image dataset, stored as individual PNG files, using PyTorch's data loading utilities. For a more in-depth discussion, please see the official\n",
    "\n",
    "- [Data Loading and Processing Tutorial](http://pytorch.org/tutorials/beginner/data_loading_tutorial.html)\n",
    "- [torch.utils.data](http://pytorch.org/docs/master/data.html) API documentation\n",
    "\n",
    "In this example, we are using the Quickdraw dataset consisting of hand-drawn objects, which is available at https://quickdraw.withgoogle.com.\n",
    "\n",
    "To execute the following examples, you need to download the `.npy` files (bitmaps in NumPy format). You don't need to download all 345 categories; a subset you are interested in is sufficient. The groups/subsets can be individually downloaded from https://console.cloud.google.com/storage/browser/quickdraw_dataset/full/numpy_bitmap\n",
    "\n",
    "Unfortunately, the Google cloud storage currently does not support selecting and downloading multiple groups at once. Thus, in order to download all groups most conveniently, we need to use their `gsutil` (https://cloud.google.com/storage/docs/gsutil_install) tool. Once it is installed, you can use\n",
    "\n",
    "    mkdir quickdraw-npy\n",
    "    gsutil -m cp gs://quickdraw_dataset/full/numpy_bitmap/*.npy quickdraw-npy\n",
    "\n",
    "Note that if you download the whole dataset, this will take up 37 GB of storage space.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import Dataset\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import transforms\n",
    "\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After downloading the dataset to a local directory, `quickdraw-npy`, the next step is to select certain groups we are interested in analyzing. Let's say we are interested in the following groups defined in the `label_dict` in the next code cell:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "label_dict = {\n",
    "         \"lollipop\": 0,\n",
    "         \"binoculars\": 1,\n",
    "         \"mouse\": 2,\n",
    "         \"basket\": 3,\n",
    "         \"penguin\": 4,\n",
    "         \"washing machine\": 5,\n",
    "         \"canoe\": 6,\n",
    "         \"eyeglasses\": 7,\n",
    "         \"beach\": 8,\n",
    "         \"screwdriver\": 9,\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dictionary values serve as class labels that we could use for a classification task, for example."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Conversion to PNG files"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we are going to convert the groups we are interested in (specified in the dictionary above) to individual PNG files using a helper function (note that this might take a while):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load utilities from ../helper.py\n",
    "import sys\n",
    "sys.path.insert(0, '..') \n",
    "from helper import quickdraw_npy_to_imagefile\n",
    "\n",
    "    \n",
    "quickdraw_npy_to_imagefile(inpath='quickdraw-npy',\n",
    "                           outpath='quickdraw-png_set1',\n",
    "                           subset=label_dict.keys())"
   ]
  },
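  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The helper `quickdraw_npy_to_imagefile` lives in `../helper.py` and is not shown in this notebook. As a rough, self-contained sketch of what such a conversion could look like, the cell below writes each row of a `.npy` bitmap array as a 28x28 grayscale PNG. The function name `npy_to_png_sketch`, the synthetic `canoe.npy` file, and the assumption that each row holds 784 `uint8` pixel values are illustrative choices made here, not the actual helper implementation; the `{name}_{index:06d}.png` naming merely mimics the file names that appear in the label files further below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "\n",
    "def npy_to_png_sketch(npy_file, out_dir, name):\n",
    "    # Hypothetical stand-in for the helper: writes one PNG per bitmap row.\n",
    "    # Assumes an array of shape (num_images, 784) holding uint8 pixels.\n",
    "    bitmaps = np.load(npy_file)\n",
    "    target_dir = os.path.join(out_dir, name)\n",
    "    os.makedirs(target_dir, exist_ok=True)\n",
    "    for i, row in enumerate(bitmaps):\n",
    "        png = Image.fromarray(row.reshape(28, 28).astype(np.uint8))\n",
    "        png.save(os.path.join(target_dir, f'{name}_{i:06d}.png'))\n",
    "    return len(bitmaps)\n",
    "\n",
    "\n",
    "# Try the sketch on three synthetic bitmaps inside a temporary directory\n",
    "with tempfile.TemporaryDirectory() as tmp:\n",
    "    fake_bitmaps = np.random.RandomState(0).randint(0, 256, size=(3, 784)).astype(np.uint8)\n",
    "    fake_npy = os.path.join(tmp, 'canoe.npy')\n",
    "    np.save(fake_npy, fake_bitmaps)\n",
    "\n",
    "    num_written = npy_to_png_sketch(fake_npy, os.path.join(tmp, 'png'), 'canoe')\n",
    "    written_files = sorted(os.listdir(os.path.join(tmp, 'png', 'canoe')))\n",
    "\n",
    "    # PNG is lossless, so reading a file back should give the original pixels\n",
    "    with Image.open(os.path.join(tmp, 'png', 'canoe', written_files[0])) as png:\n",
    "        restored = np.array(png)\n",
    "    roundtrip_ok = np.array_equal(restored, fake_bitmaps[0].reshape(28, 28))\n",
    "\n",
    "print(num_written, written_files, roundtrip_ok)"
   ]
  },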
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preprocessing into train/valid/test subsets and creating label files"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For convenience, let's create a CSV file mapping file names to class labels. First, let's collect the files and labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num paths: 1515745\n",
      "Num labels: 1515745\n"
     ]
    }
   ],
   "source": [
    "paths, labels = [], []\n",
    "\n",
    "main_dir = 'quickdraw-png_set1/'\n",
    "\n",
    "for d in os.listdir(main_dir):\n",
    "    subdir = os.path.join(main_dir, d)\n",
    "    if not os.path.isdir(subdir):\n",
    "        continue\n",
    "    for f in os.listdir(subdir):\n",
    "        path = os.path.join(d, f)\n",
    "        paths.append(path)\n",
    "        labels.append(label_dict[d])\n",
    "        \n",
    "print('Num paths:', len(paths))\n",
    "print('Num labels:', len(labels))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we shuffle the dataset and assign 70% of the dataset for training, 10% for validation, and 20% for testing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mlxtend.preprocessing import shuffle_arrays_unison\n",
    "\n",
    "\n",
    "paths2, labels2 = shuffle_arrays_unison(arrays=[np.array(paths), np.array(labels)], random_seed=3)\n",
    "\n",
    "\n",
    "cut1 = int(len(paths)*0.7)\n",
    "cut2 = int(len(paths)*0.8)\n",
    "\n",
    "paths_train, labels_train = paths2[:cut1], labels2[:cut1]\n",
    "paths_valid, labels_valid = paths2[cut1:cut2], labels2[cut1:cut2]\n",
    "paths_test, labels_test = paths2[cut2:], labels2[cut2:]"
   ]
  },
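  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shuffling and 70/10/20 slicing above can also be sketched with NumPy alone, for instance if `mlxtend` is not installed. The toy arrays `demo_paths` and `demo_labels` are made up for illustration, and a NumPy permutation with the same seed is not guaranteed to reproduce the exact ordering produced by `shuffle_arrays_unison`. The key point is that one shared permutation is applied to both arrays, so each path keeps its label."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Toy data (hypothetical): ten file names with alternating labels 0/1\n",
    "demo_paths = np.array([f'img_{i:03d}.png' for i in range(10)])\n",
    "demo_labels = np.arange(10) % 2\n",
    "\n",
    "# One shared permutation keeps paths and labels aligned\n",
    "demo_perm = np.random.RandomState(3).permutation(len(demo_paths))\n",
    "demo_paths, demo_labels = demo_paths[demo_perm], demo_labels[demo_perm]\n",
    "\n",
    "# Same cut points as above: first 70% train, next 10% valid, rest test\n",
    "demo_cut1 = int(len(demo_paths) * 0.7)\n",
    "demo_cut2 = int(len(demo_paths) * 0.8)\n",
    "\n",
    "demo_train = demo_paths[:demo_cut1]\n",
    "demo_valid = demo_paths[demo_cut1:demo_cut2]\n",
    "demo_test = demo_paths[demo_cut2:]\n",
    "\n",
    "split_sizes = (len(demo_train), len(demo_valid), len(demo_test))\n",
    "print(split_sizes)"
   ]
  },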
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, let us create CSV files, one per subset, that map the file paths to the class labels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Label</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Path</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>penguin/penguin_112865.png</th>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>binoculars/binoculars_040058.png</th>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>eyeglasses/eyeglasses_208525.png</th>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>basket/basket_050889.png</th>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lollipop/lollipop_037650.png</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                  Label\n",
       "Path                                   \n",
       "penguin/penguin_112865.png            4\n",
       "binoculars/binoculars_040058.png      1\n",
       "eyeglasses/eyeglasses_208525.png      7\n",
       "basket/basket_050889.png              3\n",
       "lollipop/lollipop_037650.png          0"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.DataFrame(\n",
    "    {'Path': paths_train,\n",
    "     'Label': labels_train,\n",
    "    })\n",
    "\n",
    "df = df.set_index('Path')\n",
    "df.to_csv('quickdraw_png_set1_train.csv')\n",
    "\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Label</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Path</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>basket/basket_046788.png</th>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>basket/basket_074810.png</th>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>binoculars/binoculars_077567.png</th>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>screwdriver/screwdriver_007042.png</th>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>binoculars/binoculars_037327.png</th>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                    Label\n",
       "Path                                     \n",
       "basket/basket_046788.png                3\n",
       "basket/basket_074810.png                3\n",
       "binoculars/binoculars_077567.png        1\n",
       "screwdriver/screwdriver_007042.png      9\n",
       "binoculars/binoculars_037327.png        1"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.DataFrame(\n",
    "    {'Path': paths_valid,\n",
    "     'Label': labels_valid,\n",
    "    })\n",
    "\n",
    "df = df.set_index('Path')\n",
    "df.to_csv('quickdraw_png_set1_valid.csv')\n",
    "\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Label</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Path</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>eyeglasses/eyeglasses_051135.png</th>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mouse/mouse_059179.png</th>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>basket/basket_046866.png</th>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>penguin/penguin_043578.png</th>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>screwdriver/screwdriver_081750.png</th>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                    Label\n",
       "Path                                     \n",
       "eyeglasses/eyeglasses_051135.png        7\n",
       "mouse/mouse_059179.png                  2\n",
       "basket/basket_046866.png                3\n",
       "penguin/penguin_043578.png              4\n",
       "screwdriver/screwdriver_081750.png      9"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.DataFrame(\n",
    "    {'Path': paths_test,\n",
    "     'Label': labels_test,\n",
    "    })\n",
    "\n",
    "df = df.set_index('Path')\n",
    "df.to_csv('quickdraw_png_set1_test.csv')\n",
    "\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's open one of the images to make sure they look ok:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(28, 28)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAQfklEQVR4nO3df4xV5Z3H8c9XyoCiRFkGllCVohDxB0vhBoluRK1bhcQoJNViaFghC/4imhiyin9UQ2JkFWujUKWCpcJSSgRFo7slKDH1D8JVQJmdgIhjSyEwKLEYjTDDd/+YazrinOcM59xf4/N+JZM7c7/3ueeby3w4d+5zznnM3QXg+++0WjcAoDoIOxAJwg5EgrADkSDsQCR+UM2NDRw40IcNG1bNTQJRaWlp0eHDh62rWq6wm9kNkn4tqZek5939sdDjhw0bpmKxmGeTAAIKhUJiLfPbeDPrJWmxpEmSLpY0zcwuzvp8ACorz9/s4yXtcfe97n5M0h8k3VSetgCUW56wD5X0104/7yvd9y1mNtvMimZWbG1tzbE5AHnkCXtXHwJ859hbd1/q7gV3LzQ2NubYHIA88oR9n6RzO/38Q0n787UDoFLyhH2rpBFm9iMza5D0c0kbytMWgHLLPPXm7m1mdo+k/1XH1Ntyd28qW2dVdvTo0WD9rLPOqlInQGXkmmd399clvV6mXgBUEIfLApEg7EAkCDsQCcIORIKwA5Eg7EAkqno+e15tbW2JtSVLlgTHPv/888H6rl27gvW33norsXbFFVcEx6LnSTvuoqkpfEjJxReHTwDt37//KfeUF3t2IBKEHYgEYQciQdiBSBB2IBKEHYhEj5p6e/XVVxNr9957b3DsVVddFayff/75wfq0adMSa9u2bQuOHTBgQLCO6ps6dWqw/vLLLwfraQuiTpgwIVgPTeX27ds3ODYr9uxAJAg7EAnCDkSCsAORIOxAJAg7EAnCDkSiR82z79ixI7HW0NAQHLtx48ZgPe0U1/HjxyfWZs2aFRy7bt26YN2syxV2kdPWrVsTa+vXrw+OveOOO4L1Sy65JFhPO+5j5syZibVVq1YFx2b9fWHPDkSCsAORIOxAJAg7EAnCDkSCsAORIOxAJHrUPPvu3bsTayNGjAiOTZuHv+yyy4L1p556KrGWNif7zDPPBOtz584N1pHN4sWLE2tnn312cOyiRYuC9TPOOCNYb29vD9bvu+++xNrkyZODY6dPnx6sJ8kVdjNrkXRUUrukNncv5Hk+AJVTjj37Ne5+uAzPA6CC+JsdiETesLukP5nZu2Y2u6sHmNlsMyuaWbG1tTXn5gBklTfsV7r7WEmTJN1tZt+5qqO7L3X3grsXGhsbc24OQFa5wu7u+0u3hyStl5R8ahiAmsocdjPrZ2ZnffO9pJ9K2lmuxgCUV55P4wdLWl86t/YHkv7b3f+nLF0lCC2jmzZvmtecOXMSa5s2bQqOnTdvXrB+/fXXB+sjR44M1mN1+HB4EmjNmjWJtTvvvDM4Nm0ePU3a+eyhefaWlpZc206SOezuvlfSv5SxFwAVxNQbEAnCDkSCsAORIOxAJAg7EIkedYpr6BK6X331VRU7+bZnn302WA8tNS1Jy5YtC9YXLlx4yj3FYOXKlcH6119/nVi76667yt1O3WPPDkSCsAORIOxAJAg7EAnCDkSCsAORIOxAJHrUPHto2eRHHnkkOPaLL74I1s8888xMPUnSgAEDgvUbb7wxWF+9enWwzjx71955551gPbSs8oUXXljuduoee3YgEoQdiARhByJB2IFIEHYgEoQdiARhByLRo+bZx4wZk1hra2sLjt2zZ0/m585r9OjRwfratWsrtu3vs2KxGKxPnDixSp30DOzZgUgQdiAShB2IBGEHIkHYgUgQdiAShB2IRI+aZ+/Vq1fmse5exk5QDUeOHAnW05Y2Di2LHKPUPbuZLTezQ2a2s9N9A8xso5l9WLo9p7JtAsir
O2/jfyfphpPue0DSJncfIWlT6WcAdSw17O7+tqTPTrr7JkkrSt+vkHRzedsCUG5ZP6Ab7O4HJKl0OyjpgWY228yKZlZsbW3NuDkAeVX803h3X+ruBXcvNDY2VnpzABJkDftBMxsiSaXbQ+VrCUAlZA37BkkzSt/PkPRKedoBUCmp8+xmtlrS1ZIGmtk+Sb+U9JikP5rZLEl/kfSzSjaJODU1NeUan3Ydgdikht3dpyWUflLmXgBUEIfLApEg7EAkCDsQCcIORIKwA5HoUae4Ii7Hjh3LNf70008vUyffD+zZgUgQdiAShB2IBGEHIkHYgUgQdiAShB2IBPPsQAUcP34889jevXuXsZN/YM8ORIKwA5Eg7EAkCDsQCcIORIKwA5Eg7EAkopln/+ijj4L1N998M1jfvXt35m1v27Yt81hJmjNnTq7xldSnT59gfcKECYm1a6+9Njj24MGDmXr6xsCBA3ONz6O5uTnz2JEjR5axk39gzw5EgrADkSDsQCQIOxAJwg5EgrADkSDsQCSimWe/7bbbgvW2trZg/bzzzkusnXZa+P/M/fv3B+tp5y9v3LgxWK+lo0ePButPP/105uceNGhQsJ72uu/cuTPzc/fv3z9YT7Nhw4bMY8eNG5dr20lS9+xmttzMDpnZzk73PWxmfzOz7aWvyRXpDkDZdOdt/O8k3dDF/b9y9zGlr9fL2xaAcksNu7u/LemzKvQCoILyfEB3j5m9X3qbf07Sg8xstpkVzazY2tqaY3MA8sga9t9IukDSGEkHJC1KeqC7L3X3grsXGhsbM24OQF6Zwu7uB9293d1PSPqtpPHlbQtAuWUKu5kN6fTjFEnJcxwA6kLqPLuZrZZ0taSBZrZP0i8lXW1mYyS5pBZJZTnh+tFHHw3WFyxYkPm50+YuBw8eHKw3NDQk1j7//PPg2I8//jhYX7NmTbB+yy23BOv1bPPmzYm1SZMmBcceOXIkWD9x4kSwPmXKlMRa6N9TkqZPnx6sT5s2LVhfuHBh5vGhYzrySA27u3fV1bIK9AKggjhcFogEYQciQdiBSBB2IBKEHYiEuXvVNlYoFLxYLCbWZ86cGRz/wgsvJNbGjh2bua9KGz16dLC+fPnyYN3MytlOVbW3tyfWnnzyyeDYdevWBetbtmwJ1qv5u32yfv36Beu7du1KrA0dOjTzdguFgorFYpe/MOzZgUgQdiAShB2IBGEHIkHYgUgQdiAShB2IRF1dSnr48OGZx4bm76WePVfdk/Xq1SuxNm/evODYtHraZc527NgRrIekLbM9f/78YD10eq2Uby49K/bsQCQIOxAJwg5EgrADkSDsQCQIOxAJwg5Eoq7m2UeNGpV57KpVq4L1tEsDo+dJW2Houuuuy/zcaWM//fTTYP3xxx8P1h966KHE2kUXXRQcmxV7diAShB2IBGEHIkHYgUgQdiAShB2IBGEHIlFX8+xTp04N1q+55prE2v333x8cO3LkyGB9/PjxwTrQ2a233hqspy3ZHFrGu2bz7GZ2rpm9ZWbNZtZkZveW7h9gZhvN7MPS7TkV6RBAWXTnbXybpPvdfZSkCZLuNrOLJT0gaZO7j5C0qfQzgDqVGnZ3P+Du75W+PyqpWdJQSTdJWlF62ApJN1eoRwBlcEof0JnZMEk/lrRF0mB3PyB1/IcgaVDCmNlmVjSzYto1wwBUTrfDbmZnSnpJ0n3u/vfujnP3pe5ecPdC2okLACqnW2E3s97qCPoqd/9mac2DZjakVB8i6VBlWgRQDqlTb9ZxDeZlkprdvfMauxskzZD0WOn2lbzNpF3u+bnnnkuspZ2SePnllwfraafAzp07N7HGtF18jh07lmv8aadV/xCX7syzXynpF5I+MLPtpfvmqyPkfzSzWZL+IulnFekQQFmkht3d/ywpaZf7k/K2A6BSOFwWiARhByJB2IFIEHYgEoQdiERdneKaZsSIEYm15ubm4NgnnngiWF+0aFGwvnLlysTapZdeGhw7c+bM
YD1tjp8jD+vPa6+9Fqw3NDQE6+PGjStnO93Cnh2IBGEHIkHYgUgQdiAShB2IBGEHIkHYgUiYu1dtY4VCwYvFYtW2dyq+/PLLYP2ll15KrC1fvjw4dvPmzcF62pzs7bffHqwvWLAgscYcfTZvvPFGsJ52bMTEiROD9XXr1gXrWRUKBRWLxS7PUmXPDkSCsAORIOxAJAg7EAnCDkSCsAORIOxAJJhnr4K9e/cG64sXL85V79u3b2LtggsuCI5Nu375qFGjgvW0a+anHUNQSUeOHEmsNTU1Bce++OKLwfrYsWOD9bVr1wbrw4cPD9azYp4dAGEHYkHYgUgQdiAShB2IBGEHIkHYgUikzrOb2bmSfi/pnyWdkLTU3X9tZg9L+g9JraWHznf310PPFes8e15p8/Sha+IfPnw4OPb48ePB+rZt24L1Tz75JFivV3369AnWH3zwwWB9/vz5wXrv3r1PuadyCM2zd2eRiDZJ97v7e2Z2lqR3zWxjqfYrdw+vvgCgLnRnffYDkg6Uvj9qZs2Shla6MQDldUp/s5vZMEk/lrSldNc9Zva+mS03s3MSxsw2s6KZFVtbW7t6CIAq6HbYzexMSS9Jus/d/y7pN5IukDRGHXv+LhdLc/el7l5w9wLXQwNqp1thN7Pe6gj6KndfJ0nuftDd2939hKTfSgqfEQGgplLDbmYmaZmkZnd/stP9Qzo9bIqkneVvD0C5dOfT+Csl/ULSB2a2vXTffEnTzGyMJJfUImlOBfqD0k+HXLJkSZU6+a5Dhw4F6+3t7Ym10Km55RA6vbZfv34V3XY96s6n8X+W1NW8XXBOHUB94Qg6IBKEHYgEYQciQdiBSBB2IBKEHYhEd+bZgUSDBg2qdQvoJvbsQCQIOxAJwg5EgrADkSDsQCQIOxAJwg5EoqpLNptZq6TO1x4eKCl8rePaqdfe6rUvid6yKmdv57t7l9d/q2rYv7Nxs6K7F2rWQEC99lavfUn0llW1euNtPBAJwg5EotZhX1rj7YfUa2/12pdEb1lVpbea/s0OoHpqvWcHUCWEHYhETcJuZjeY2S4z22NmD9SihyRm1mJmH5jZdjOr6frSpTX0DpnZzk73DTCzjWb2Yem2yzX2atTbw2b2t9Jrt93MJteot3PN7C0zazazJjO7t3R/TV+7QF9Ved2q/je7mfWStFvSv0naJ2mrpGnu/n9VbSSBmbVIKrh7zQ/AMLOrJH0h6ffufmnpvv+S9Jm7P1b6j/Icd//POuntYUlf1HoZ79JqRUM6LzMu6WZJ/64avnaBvm5RFV63WuzZx0va4+573f2YpD9IuqkGfdQ9d39b0mcn3X2TpBWl71eo45el6hJ6qwvufsDd3yt9f1TSN8uM1/S1C/RVFbUI+1BJf+308z7V13rvLulPZvaumc2udTNdGOzuB6SOXx5J9XZdqNRlvKvppGXG6+a1y7L8eV61CHtXS0nV0/zfle4+VtIkSXeX3q6ie7q1jHe1dLHMeF3Iuvx5XrUI+z5J53b6+YeS9tegjy65+/7S7SFJ61V/S1Ef/GYF3dJteGXFKqqnZby7WmZcdfDa1XL581qEfaukEWb2IzNrkPRzSRtq0Md3mFm/0gcnMrN+kn6q+luKeoOkGaXvZ0h6pYa9fEu9LOOdtMy4avza1Xz5c3ev+pekyer4RP4jSQ/VooeEvoZL2lH6aqp1b5JWq+Nt3XF1vCOaJemfJG2S9GHpdkAd9faipA8kva+OYA2pUW//qo4/Dd+XtL30NbnWr12gr6q8bhwuC0SCI+iASBB2IBKEHYgEYQciQdiBSBB2IBKEHYjE/wP8VOMwG0xOiAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "main_dir = 'quickdraw-png_set1/'\n",
    "\n",
    "# df currently holds the test subset created in the previous cell\n",
    "img = Image.open(os.path.join(main_dir, df.index[1]))\n",
    "img = np.asarray(img, dtype=np.uint8)\n",
    "print(img.shape)\n",
    "plt.imshow(img, cmap='binary');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Implementing a Custom Dataset Class"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we implement a custom `Dataset` for reading the images. The `__getitem__` method will\n",
    "\n",
    "1. read a single image from disk based on an `index` (more on batching later)\n",
    "2. perform a custom image transformation (if a `transform` argument is provided in the `__init__` constructor)\n",
    "3. return a single image and its corresponding label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "class QuickdrawDataset(Dataset):\n",
    "    \"\"\"Custom Dataset for loading Quickdraw images\"\"\"\n",
    "\n",
    "    def __init__(self, txt_path, img_dir, transform=None):\n",
    "    \n",
    "        df = pd.read_csv(txt_path, sep=\",\", index_col=0)\n",
    "        self.img_dir = img_dir\n",
    "        self.txt_path = txt_path\n",
    "        self.img_names = df.index.values\n",
    "        self.y = df['Label'].values\n",
    "        self.transform = transform\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        img = Image.open(os.path.join(self.img_dir,\n",
    "                                      self.img_names[index]))\n",
    "        \n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "        \n",
    "        label = self.y[index]\n",
    "        return img, label\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.y.shape[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have created our custom Dataset class, let us add some custom transformations via the `transforms` utilities from `torchvision`. Here, we\n",
    "\n",
    "1. normalize the images (here: dividing by 255)\n",
    "2. convert the images into PyTorch tensors\n",
    "\n",
    "Then, we initialize a Dataset instance for the training images using the 'quickdraw_png_set1_train.csv' label file (we omit the validation and test sets, but the same concepts apply).\n",
    "\n",
    "Finally, we initialize a `DataLoader` that allows us to read from the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note that transforms.ToTensor()\n",
    "# already divides pixels by 255. internally\n",
    "\n",
    "custom_transform = transforms.Compose([#transforms.Lambda(lambda x: x/255.),\n",
    "                                       transforms.ToTensor()])\n",
    "\n",
    "train_dataset = QuickdrawDataset(txt_path='quickdraw_png_set1_train.csv',\n",
    "                                 img_dir='quickdraw-png_set1/',\n",
    "                                 transform=custom_transform)\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset,\n",
    "                          batch_size=128,\n",
    "                          shuffle=True,\n",
    "                          num_workers=4) "
   ]
  },
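  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick, hedged check of the code comment above that `transforms.ToTensor()` already rescales pixel values, the following cell applies it to a tiny hand-made grayscale image. The 2x2 image `demo_img` is invented for illustration; for an 8-bit PIL image, the result should be a `float32` tensor in channels-first layout with values in the range [0, 1]."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from PIL import Image\n",
    "from torchvision import transforms\n",
    "\n",
    "# Tiny 2x2 8-bit grayscale image (hypothetical values)\n",
    "demo_img = Image.fromarray(np.array([[0, 255], [128, 64]], dtype=np.uint8))\n",
    "\n",
    "# ToTensor adds a channel axis and divides the uint8 values by 255\n",
    "demo_tensor = transforms.ToTensor()(demo_img)\n",
    "\n",
    "print(demo_tensor.shape, demo_tensor.dtype)\n",
    "print(demo_tensor.min().item(), demo_tensor.max().item())"
   ]
  },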
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That's it. Now we can iterate over an epoch using `train_loader` as an iterator and use the features and labels from the training dataset for model training:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Iterating Through the Custom Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1 | Batch index: 0 | Batch size: 128\n",
      "Epoch: 2 | Batch index: 0 | Batch size: 128\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "torch.manual_seed(0)\n",
    "\n",
    "num_epochs = 2\n",
    "for epoch in range(num_epochs):\n",
    "\n",
    "    for batch_idx, (x, y) in enumerate(train_loader):\n",
    "        \n",
    "        print('Epoch:', epoch+1, end='')\n",
    "        print(' | Batch index:', batch_idx, end='')\n",
    "        print(' | Batch size:', y.size()[0])\n",
    "        \n",
    "        x = x.to(device)\n",
    "        y = y.to(device)\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Just to make sure that the batches are being loaded correctly, let's print out the dimensions of the most recently loaded batch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([128, 1, 28, 28])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, each batch consists of 128 images, just as specified. One thing to keep in mind, though, is that\n",
    "PyTorch uses a different image layout (which is more efficient when working with CUDA); here, the image axes are \"num_images x channels x height x width\" (NCHW) instead of \"num_images x height x width x channels\" (NHWC)."
   ]
  },
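  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The layout difference described above can be sketched with a dummy batch. The array `demo_nhwc` is a made-up stand-in for four 28x28 single-channel images stored channels-last; `permute` only reorders the axes, so converting to NCHW and back leaves the values untouched."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# Dummy channels-last batch: (num_images, height, width, channels)\n",
    "demo_nhwc = np.arange(4 * 28 * 28, dtype=np.float32).reshape(4, 28, 28, 1)\n",
    "\n",
    "# Reorder the axes to PyTorch's channels-first layout: (N, C, H, W)\n",
    "demo_nchw = torch.from_numpy(demo_nhwc).permute(0, 3, 1, 2)\n",
    "\n",
    "# ... and back to channels-last, e.g., for plotting\n",
    "demo_back = demo_nchw.permute(0, 2, 3, 1)\n",
    "\n",
    "print(demo_nchw.shape, demo_back.shape)"
   ]
  },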
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To visually check that the images coming out of the data loader are intact, let's swap the axes of one image to \"height x width x channels\" (HWC) so that we can visualize it via `imshow`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([28, 28, 1])"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "one_image = x[0].permute(1, 2, 0)\n",
    "one_image.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAQxklEQVR4nO3df2xVZZoH8O9XnBGpw49CkQpIZcQoAivY4CaSiWTYiRIUiQ4OwRGjWfBXdKLRNe4fozEmqDtjJrhO0hlgmKVqMDMKJsalkjFmEjNSoEuLuOCS2kHQtpFfgqLAs3/0OFboec/lnnPvueX5fpLm3t7nvj2PV749t/c957w0M4jIme+svBsQkfJQ2EWcUNhFnFDYRZxQ2EWcOLucGxsxYoTV1dWVc5MirrS3t6O7u5t91VKFneS1AH4DYACA35vZ0tDz6+rq0NzcnGaTIhJQX18fWyv6bTzJAQD+E8B1ACYCWEByYrE/T0RKK83f7NMBfGhmu8zsKwAvA5ibTVsikrU0YR8N4O+9vt8dPfYdJBeTbCbZ3NXVlWJzIpJGmrD39SHAKcfemlmDmdWbWX1NTU2KzYlIGmnCvhvA2F7fjwGwJ107IlIqacK+EcAEkheR/D6AnwFYl01bIpK1oqfezOwYyfsA/Dd6pt5WmNm2zDoTkUylmmc3szcAvJFRLyJSQjpcVsQJhV3ECYVdxAmFXcQJhV3ECYVdxImyns/en4WuwvvQQw8Fx7744ovB+qBBg4L1oUOHButVVVWxtSVLlgTH3nrrrcG6nDm0ZxdxQmEXcUJhF3FCYRdxQmEXcUJhF3FCU28FamxsjK0999xzwbE333xzsF5dXR2s79u3L1h/++23Y2tr1qwJjtXUmx/as4s4obCLOKGwizihsIs4obCLOKGwizihsIs4oXn2Am3cuDG2ljRP/sorr2TdznfMnj07tpY0Ry9+aM8u4oTCLuKEwi7ihMIu4oTCLuKEwi7ihMIu4oSbefYtW7YE6wcPHgzWW1tbY2uTJk0qqqeshOb5Ozo6ytiJVLJUYSfZDuAQgOMAjplZfRZNiUj2stizzzSz7gx+joiUkP5mF3EibdgNwHqSm0gu7usJJBeTbCbZ3NXVlXJzIlKstGG/2symAbgOwL0kf3TyE8yswczqzay+pqYm5eZEpFipwm5me6LbTgCvApieRVMikr2iw06yiuQPvrkP4CcA2rJqTESylebT+PMBvErym5/zopm9mUlXRdi6dWuwPm3atJJtO3oNYg0fPjxYHzJkSLCetGRzZ2dnbK27OzxRknTd+MGDBwfraXpPGpu07TT1tD972LBhwXrS/7OkfzOlUHTYzWwXgH/KsBcRKSFNvYk4obCLOKGwizihsIs4obCLOHHGnOJ66aWXpho/b968YP21116Lrc2ZMyc4dsKECcH6oUOHgvX9+/cH68ePH4+t7dmzJzj2vffeC9a/+OKLYD3p1OCk+plq4cKFwfrq1avL1Mm3tGcXcUJhF3FCYRdxQmEXcUJhF3FCYRdxQmEXceKMmWc/evRoqvEzZ84M1pcuXRpbGz9+fHDs2Wfn9zIfO3YsWM+zt6TlpNPO4R84cCC2lnRsQ9LPXr58ebC+fv36YD0P2rOLOKGwizihsIs4obCLOKGwizihsIs4obCLOHHGzLMfOXIk1fiBAwcG65dcckmqn5+XPOfRkyRdjjmpnqfDhw8H601NTcF6e3t7bK2urq6IjpJpzy7ihMIu4oTCLuKEwi7ihMIu4oTCLuKEwi7iROVOwp6mDz74INX4/jqPntbnn38erCdds76qqipYDy19PGDAgODYSjZ9+vRU40PX689tnp3kCpKdJNt6PVZNsonkzui2co9+EBEAhb2N/wOAa0967FEAG8xsAoAN0fciUsESw25m7wD47KSH5wJYFd1fBeDGbNsSkawV+wHd+Wa2FwCi25FxTyS5mGQzyeaurq4iNyciaZX803gz
azCzejOrr6mpKfXmRCRGsWH/lGQtAES3ndm1JCKlUGzY1wFYFN1fBGBtNu2ISKkkzrOTfAnANQBGkNwN4JcAlgJYQ/JOAB0AflrKJguxZcuWVOOnTJmSUSfZO3HiRLC+dm3879oXXnghOPatt94qqqdCzZo1K7a2cuXK4NgxY8Zk3U5mLrvssmA96fiDzZs3x9bmz59fVE9JEsNuZgtiSj/OuBcRKSEdLivihMIu4oTCLuKEwi7ihMIu4sQZc4pr0tTbRRddFKznednipOWBb7nllmD9zTffjK1dfPHFwbFPPfVUsF5dXR2sJy1N/Prrr8fWxo0bFxyb9N+9bNmyYH348OHBehpJp+decMEFwXpnZ/mPQ9OeXcQJhV3ECYVdxAmFXcQJhV3ECYVdxAmFXcQJN/PsU6dOLVMnp+ro6AjW58yZE6zv2LEjWF+xYkVsbdGiRbE1ADjrrHS/7++6665g/eOPP46tNTQ0BMc+88wzwfquXbuC9XfffTe2RjI4Nq2k4xM+++zkyzqWnvbsIk4o7CJOKOwiTijsIk4o7CJOKOwiTijsIk70q3n2o0ePxtaSlmwu1eV5AcDMgvWZM2cG60nnsydd7nnGjBnBep5Gjx4dW3viiSeCY5OWLr7jjjuC9ba2ttja5MmTg2PT0jy7iORGYRdxQmEXcUJhF3FCYRdxQmEXcUJhF3GiX82zt7a2xta+/vrr4Nhp06Zl3c4/tLS0BOtJ5103NjYG65U8j15Ko0aNSjX+2LFjGXVy+pLm2ZOucVAKiXt2kitIdpJs6/XY4yQ/JtkSfc0ubZsiklYhb+P/AODaPh5/zsyuiL7eyLYtEclaYtjN7B0A5T+2T0QyleYDuvtIbo3e5sculEZyMclmks1dXV0pNiciaRQb9t8C+CGAKwDsBfCruCeaWYOZ1ZtZfU1NTZGbE5G0igq7mX1qZsfN7ASA3wGYnm1bIpK1osJOsrbXt/MAxJ9LKCIVIXGeneRLAK4BMILkbgC/BHANySsAGIB2AEtK1+K3kq4NH1LK68YnzakOGTIkWE9aZ/ymm24K1s8555xgvb/68ssvU40fOHBgRp2cvko8nz0x7Ga2oI+Hl5egFxEpIR0uK+KEwi7ihMIu4oTCLuKEwi7iRL86xTU09TZy5Mjg2Nra2mA9jXHjxgXrq1evDtZvuOGGYP22224L1leuXBlbGzRoUHBsJevPU2/DhsUeQQ5Al5IWkRJS2EWcUNhFnFDYRZxQ2EWcUNhFnFDYRZw4Y+bZS3mp6LTmzJkTrD///PPB+gMPPBCshy6x/fLLLwfHTpkyJVjPU2iJ7kLkeepv0jx76L/tyJEjwbHFHjuhPbuIEwq7iBMKu4gTCruIEwq7iBMKu4gTCruIExU1z378+PFgPTSffP/992fdTtncc889wXrSXPjChQtja1dddVVw7JNPPhmsJ/VWyvPl+/P57EmXkg5JOtdd8+wiEqSwizihsIs4obCLOKGwizihsIs4obCLOFFR8+w7duwI1g8fPhxbK+WSzHmbMWNGsN7S0hJbW7IkvJr2ww8/HKw//fTTwfrtt98erI8dOza2ljQP3tTUFKwnOffcc1ONT6OU8+xjxowp6ucm7tlJjiX5F5LbSW4j+UD0eDXJJpI7o9vw2foikqtC3sYfA/CQmV0G4J8B3EtyIoBHAWwwswkANkTfi0iFSgy7me01s83R/UMAtgMYDWAugFXR01YBuLFEPYpIBk7rAzqSdQCmAvgbgPPNbC/Q8wsBQJ+LrZFcTLKZZHNXV1fKdkWkWAWHneR5AP4E4BdmdrDQcWbWYGb1ZlZfU1NTTI8ikoGCwk7ye+gJeqOZ/Tl6+FOStVG9FkBnaVoUkSwkTr2RJIDlALab2a97ldYBWARgaXS7Nm0zoUtFJ0k6DTRpOiPNVEneQpctXrNmTXDspk2bgvVnn302WF+2bFmwnvZy0CEPPvhgsN5fp9727duXYSffKmSe/WoAPwfQSrIleuwx
9IR8Dck7AXQA+GlJOhSRTCSG3cz+CoAx5R9n246IlIoOlxVxQmEXcUJhF3FCYRdxQmEXcaKiTnE9++xwO5dffnlsbd26dcGxjzzySLA+cmSfR/sWtO1JkyYFxybVJ06cmGr80KFDg/WQK6+8MlhPWvI5Sei05K+++io4dsCAAcH64MGDi+qpHEp5imuxtGcXcUJhF3FCYRdxQmEXcUJhF3FCYRdxQmEXcaKi5tnnz59fdL27uzs49rzzzgvWQ8tBA8D7778fW2tsbAyOLdW86TdClxYOHR8AAJMnTw7W0x4DMH78+Nha0lx0z6UU+ifNs4tIbhR2EScUdhEnFHYRJxR2EScUdhEnFHYRJypqnj2NESNGBOt33313mTo51SeffBKst7W1Bevbtm0rup50/EBDQ0OwfvBgwYv/ZC7pPP2keuh6+sOHDy96LJA8j15VVRWsh+zfv7/osSHas4s4obCLOKGwizihsIs4obCLOKGwizihsIs4Ucj67GMB/BHAKAAnADSY2W9IPg7gXwF0RU99zMzeKFWj/dmoUaNS1WfNmpVlO6elo6MjWA+d5w8AH330UWwtaR3ypHrSed+hetLP3rlzZ7CeNP7AgQPB+tSpU2Nr119/fXBssQo5qOYYgIfMbDPJHwDYRLIpqj1nZv9Rks5EJFOFrM++F8De6P4hktsBjC51YyKSrdP6m51kHYCpAP4WPXQfya0kV5Ds8/hCkotJNpNs7urq6uspIlIGBYed5HkA/gTgF2Z2EMBvAfwQwBXo2fP/qq9xZtZgZvVmVl9TU5O+YxEpSkFhJ/k99AS90cz+DABm9qmZHTezEwB+B2B66doUkbQSw86eS3wuB7DdzH7d6/HaXk+bByB86paI5KqQT+OvBvBzAK0kW6LHHgOwgOQVAAxAO4AlJehPcnbhhRemqkvlKOTT+L8C6OsC3ppTF+lHdASdiBMKu4gTCruIEwq7iBMKu4gTCruIEwq7iBMKu4gTCruIEwq7iBMKu4gTCruIEwq7iBMKu4gTNLPybYzsAtD72sIjAHSXrYHTU6m9VWpfgHorVpa9jTOzPq//Vtawn7JxstnM6nNrIKBSe6vUvgD1Vqxy9aa38SJOKOwiTuQd9oactx9Sqb1Val+AeitWWXrL9W92ESmfvPfsIlImCruIE7mEneS1JP+X5IckH82jhzgk20m2kmwh2ZxzLytIdpJs6/VYNckmkjuj2z7X2Mupt8dJfhy9di0kZ+fU21iSfyG5neQ2kg9Ej+f62gX6KsvrVva/2UkOALADwL8A2A1gI4AFZhZe6LtMSLYDqDez3A/AIPkjAJ8D+KOZTYoeewbAZ2a2NPpFOczM/q1CenscwOd5L+MdrVZU23uZcQA3ArgdOb52gb7mowyvWx579ukAPjSzXWb2FYCXAczNoY+KZ2bvAPjspIfnAlgV3V+Fnn8sZRfTW0Uws71mtjm6fwjAN8uM5/raBfoqizzCPhrA33t9vxuVtd67AVhPchPJxXk304fzzWwv0POPB8DInPs5WeIy3uV00jLjFfPaFbP8eVp5hL2vpaQqaf7vajObBuA6APdGb1elMAUt410ufSwzXhGKXf48rTzCvhvA2F7fjwGwJ4c++mRme6LbTgCvovKWov70mxV0o9vOnPv5h0paxruvZcZRAa9dnsuf5xH2jQAmkLyI5PcB/AzAuhz6OAXJquiDE5CsAvATVN5S1OsALIruLwKwNsdevqNSlvGOW2YcOb92uS9/bmZl/wIwGz2fyP8fgH/Po4eYvsYD+J/oa1vevQF4CT1v675GzzuiOwEMB7ABwM7otrqCevsvAK0AtqInWLU59TYDPX8abgXQEn3Nzvu1C/RVltdNh8uKOKEj6EScUNhFnFDYRZxQ2EWcUNhFnFDYRZxQ2EWc+H9MFRFikCzOSwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# note that imshow also works fine with scaled\n",
    "# images in [0, 1] range.\n",
    "plt.imshow(one_image.to(torch.device('cpu')).squeeze(), cmap='binary');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch      : 1.10.1\n",
      "PIL        : 9.0.1\n",
      "sys        : 3.8.8 | packaged by conda-forge | (default, Feb 20 2021, 16:22:27) \n",
      "[GCC 9.3.0]\n",
      "torchvision: 0.11.2\n",
      "pandas     : 1.2.5\n",
      "numpy      : 1.22.2\n",
      "matplotlib : 3.3.4\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%watermark -iv"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
